# -*- coding: utf-8 -*-
# MIT License
# 
# Copyright (c) 2016 David Sandberg
# 
# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
# in the Software without restriction, including without limitation the rights
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
# copies of the Software, and to permit persons to whom the Software is
# furnished to do so, subject to the following conditions:
# 
# The above copyright notice and this permission notice shall be included in all
# copies or substantial portions of the Software.
# 
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
# SOFTWARE.

# 导入需要用到的库
from datetime import datetime
import os.path
import time
import sys
import random
import tensorflow as tf
import numpy as np
import importlib
import argparse          # 命令项选项与参数解析的模块
import facenet
import lfw
import h5py             # h5py文件是存放两类对象的容器，数据集(dataset)和组(group)
import tensorflow.contrib.slim as slim
from tensorflow.python.ops import data_flow_ops
from tensorflow.python.framework import ops
from tensorflow.python.ops import array_ops
from AM_softmax import AM_logits_compute
# from validation_tool import validation
import os
import math
from scipy import misc           # 图像处理


# os.environ['CUDA_VISIBLE_DEVICES'] = '1'
def main(args):
    """Build the training graph, train with an AM-softmax loss and validate on LFW.

    Args:
        args: Parsed command-line namespace (see ``parse_arguments``).

    Returns:
        The directory this run writes its checkpoints to.
    """
    # Import the module that defines the network architecture.
    network = importlib.import_module(args.model_def)
    # Each run gets its own timestamped sub-directory for logs ...
    subdir = datetime.strftime(datetime.now(), '%Y%m%d-%H%M%S')
    log_dir = os.path.join(os.path.expanduser(args.logs_base_dir), subdir)
    if not os.path.isdir(log_dir):
        os.makedirs(log_dir)
    # ... and for model checkpoints.
    model_dir = os.path.join(os.path.expanduser(args.models_base_dir), subdir)
    if not os.path.isdir(model_dir):
        os.makedirs(model_dir)

    # Seed the NumPy and Python RNGs.
    np.random.seed(seed=args.seed)
    random.seed(args.seed)
    # Load the dataset and split it; only the training part is used below.
    data_set = facenet.get_dataset(args.data_dir)
    train_set, _ = facenet.split_dataset(data_set, 0, 'SPLIT_CSV')
    if args.filter_filename:
        # Optionally drop outlier images / under-populated classes.
        train_set = filter_dataset(train_set, os.path.expanduser(args.filter_filename),
            args.filter_percentile, args.filter_min_nrof_images_per_class)
    # One entry of train_set per class.
    nrof_classes = len(train_set)

    print('Model directory: %s' % model_dir)
    print('Log directory: %s' % log_dir)
    pretrained_model = None
    if args.pretrained_model:
        pretrained_model = os.path.expanduser(args.pretrained_model)
        print('Pre-trained model: %s' % pretrained_model)
    if args.lfw_dir:
        print('LFW directory: %s' % args.lfw_dir)
        # Read the file containing the pairs used for testing
        pairs = lfw.read_pairs(os.path.expanduser(args.lfw_pairs))
        # Get the paths for the corresponding images
        lfw_paths, actual_issame = lfw.get_paths(os.path.expanduser(args.lfw_dir), pairs, args.lfw_file_ext)

    with tf.Graph().as_default():
        tf.set_random_seed(args.seed)
        global_step = tf.Variable(0, trainable=False)
        # Flat lists of image paths and their class labels.
        image_list, label_list = facenet.get_image_paths_and_labels(train_set)
        assert len(image_list) > 0, 'The dataset should not be empty'
        # Shuffled index producer; one dequeue yields the sample indices of a whole epoch.
        labels = ops.convert_to_tensor(label_list, dtype=tf.int32)
        range_size = array_ops.shape(labels)[0]
        index_queue = tf.train.range_input_producer(range_size, num_epochs=None,
                             shuffle=True, seed=None, capacity=32)
        index_dequeue_op = index_queue.dequeue_many(args.batch_size*args.epoch_size, 'index_dequeue')

        learning_rate_placeholder = tf.placeholder(tf.float32, name='learning_rate')
        batch_size_placeholder = tf.placeholder(tf.int32, name='batch_size')
        # True while training, False while evaluating.
        phase_train_placeholder = tf.placeholder(tf.bool, name='phase_train')
        image_paths_placeholder = tf.placeholder(tf.string, shape=(None, 1), name='image_paths')
        labels_placeholder = tf.placeholder(tf.int64, shape=(None, 1), name='labels')
        # FIFO queue of (path, label) pairs that feeds the preprocessing threads.
        input_queue = data_flow_ops.FIFOQueue(capacity=256000,
                                    dtypes=[tf.string, tf.int64],
                                    shapes=[(1,), (1,)],
                                    shared_name=None, name=None)
        enqueue_op = input_queue.enqueue_many([image_paths_placeholder, labels_placeholder], name='enqueue_op')

        # Number of preprocessing pipelines handed to batch_join. This used to be
        # hard-coded to 4, which silently ignored --nrof_preprocess_threads; the
        # getattr fallback keeps callers that build their own namespace working.
        nrof_preprocess_threads = getattr(args, 'nrof_preprocess_threads', 4)
        images_and_labels = []
        # Each iteration adds one [images, label] entry; batch_join below merges
        # all entries into batches of batch_size_placeholder examples.
        for _ in range(nrof_preprocess_threads):
            filenames, label = input_queue.dequeue()
            images = []
            for filename in tf.unstack(filenames):
                file_contents = tf.read_file(filename)
                image = tf.cast(tf.image.decode_image(file_contents, channels=3), tf.float32)
                if args.random_flip:
                    image = tf.image.random_flip_left_right(image)
                # Inputs are expected to be already aligned 112x96 RGB crops.
                image.set_shape((112, 96, 3))
                # Scale pixel values: (x - 127.5) / 128.
                images.append(tf.subtract(image, 127.5) * 0.0078125)
            images_and_labels.append([images, label])

        image_batch, label_batch = tf.train.batch_join(
            images_and_labels, batch_size=batch_size_placeholder,
            shapes=[(112, 96, 3), ()], enqueue_many=True,
            capacity=4 * nrof_preprocess_threads * args.batch_size,
            allow_smaller_final_batch=True)
        # Give the batch tensors stable names ('input', 'label_batch').
        image_batch = tf.identity(image_batch, 'input')
        label_batch = tf.identity(label_batch, 'label_batch')

        print('Total number of classes: %d' % nrof_classes)
        print('Total number of examples: %d' % len(image_list))
        print('Building training graph')

        # Build the inference graph (everything up to the embedding layer).
        prelogits, _ = network.inference(image_batch, args.keep_probability,
            phase_train=phase_train_placeholder, bottleneck_layer_size=args.embedding_size,
            weight_decay=args.weight_decay)

        embeddings = tf.nn.l2_normalize(prelogits, 1, 1e-10, name='embeddings')
        AM_logits = AM_logits_compute(embeddings, label_batch, args, nrof_classes)
        # Staircase exponential decay applied on top of the fed learning rate.
        learning_rate = tf.train.exponential_decay(learning_rate_placeholder, global_step,
            args.learning_rate_decay_epochs*args.epoch_size, args.learning_rate_decay_factor, staircase=True)
        tf.summary.scalar('learning_rate', learning_rate)
        # Cross entropy on the margin-adjusted logits.
        cross_entropy = tf.nn.sparse_softmax_cross_entropy_with_logits(
            labels=label_batch, logits=AM_logits, name='cross_entropy_per_example')
        cross_entropy_mean = tf.reduce_mean(cross_entropy, name='cross_entropy')

        # Add an L2 penalty for every variable whose name matches 'kernel'.
        for weights in slim.get_variables_by_name('kernel'):
            kernel_regularization = tf.contrib.layers.l2_regularizer(args.weight_decay)(weights)
            print(weights)
            tf.add_to_collection(tf.GraphKeys.REGULARIZATION_LOSSES, kernel_regularization)
        regularization_losses = tf.get_collection(tf.GraphKeys.REGULARIZATION_LOSSES)

        if args.weight_decay == 0:
            total_loss = tf.add_n([cross_entropy_mean], name='total_loss')
        else:
            total_loss = tf.add_n([cross_entropy_mean] + regularization_losses, name='total_loss')
        tf.add_to_collection('losses', total_loss)

        # Two savers so that fine-tuning on a different dataset can restore and
        # save independently. Both cover only tf.trainable_variables().
        # NOTE(review): non-trainable variables (global_step, and any moving
        # statistics the network may keep) are therefore not checkpointed - confirm
        # this is intended.
        saver_load = tf.train.Saver(tf.trainable_variables(), max_to_keep=1)
        saver_save = tf.train.Saver(tf.trainable_variables(), max_to_keep=1)

        # Momentum SGD over all trainable variables.
        train_op = tf.train.MomentumOptimizer(learning_rate, momentum=0.9).minimize(total_loss, global_step=global_step, var_list=tf.trainable_variables())
        summary_op = tf.summary.merge_all()

        # Session configuration. The GPU options must be part of the config
        # *before* the session is created; previously the memory fraction was
        # built but never passed on, so --gpu_memory_fraction had no effect.
        gpu_options = tf.GPUOptions(
            per_process_gpu_memory_fraction=args.gpu_memory_fraction,
            allow_growth=True)
        config = tf.ConfigProto(gpu_options=gpu_options, log_device_placement=False,
                                allow_soft_placement=True)
        sess = tf.Session(config=config)
        # Initialize variables
        sess.run(tf.global_variables_initializer())
        sess.run(tf.local_variables_initializer())
        summary_writer = tf.summary.FileWriter(log_dir, sess.graph)
        # Start the queue runner threads that drive the input pipeline.
        coord = tf.train.Coordinator()
        tf.train.start_queue_runners(coord=coord, sess=sess)

        with sess.as_default():
            if pretrained_model:
                print('Restoring pretrained model: %s' % pretrained_model)
                saver_load.restore(sess, pretrained_model)

            print('Running training')
            best_accuracy = 0.0

            while True:
                # global_step counts processed batches; derive the epoch from it.
                step = sess.run(global_step, feed_dict=None)
                epoch = step // args.epoch_size
                # Test the freshly computed epoch. Testing the value from the
                # previous iteration made the loop run one epoch too many.
                if epoch >= args.max_nrof_epochs:
                    break
                # Enqueue and train one epoch.
                train(args, sess, epoch, image_list, label_list, index_dequeue_op, enqueue_op, image_paths_placeholder, labels_placeholder,
                    learning_rate_placeholder, phase_train_placeholder, batch_size_placeholder, global_step,
                    total_loss, train_op, summary_op, summary_writer, regularization_losses, args.learning_rate_schedule_file)

                print('validation running...')
                if args.lfw_dir:
                    # evaluate() saves a checkpoint whenever LFW accuracy improves.
                    best_accuracy = evaluate(sess, enqueue_op, image_paths_placeholder, labels_placeholder, phase_train_placeholder, batch_size_placeholder, embeddings,
                        label_batch, lfw_paths, actual_issame, args.lfw_batch_size, args.lfw_nrof_folds, log_dir, step, summary_writer, best_accuracy, saver_save, model_dir, subdir)
                else:
                    # Without validation nothing else would ever write a checkpoint.
                    save_variables_and_metagraph(sess, saver_save, summary_writer, model_dir, subdir, step)
    return model_dir


def find_threshold(var, percentile):
    """Estimate the value of ``var`` at ``percentile`` from a 100-bin histogram.

    The empirical CDF is evaluated at the bin centers and linearly
    interpolated at ``percentile / 100``.

    Args:
        var: Array-like of values.
        percentile: Percentile on a 0-100 scale.

    Returns:
        The interpolated threshold value.
    """
    counts, edges = np.histogram(var, 100)
    centers = (edges[:-1] + edges[1:]) / 2
    cumulative = np.float32(np.cumsum(counts)) / np.sum(counts)
    return np.interp(percentile * 0.01, cumulative, centers)


# Filter outlier images (and classes that become too small) out of a dataset.
def filter_dataset(dataset, data_filename, percentile, min_nrof_images_per_class):
    """Remove images that lie far from their class center.

    The HDF5 file ``data_filename`` must hold three parallel datasets:
    ``distance_to_center``, ``label_list`` and ``image_list``. Every image
    whose distance is at or above the ``percentile`` threshold is removed from
    its class. A class touched by such a removal is dropped entirely when
    fewer than ``min_nrof_images_per_class`` images remain in it.

    Note: ``dataset`` is modified in place and the same object is returned.

    Args:
        dataset: List of class objects exposing a mutable ``image_paths`` list,
            indexed by label.
        data_filename: Path of the HDF5 file described above.
        percentile: Percentile (0-100) of distances to keep.
        min_nrof_images_per_class: Minimum class size after filtering.

    Returns:
        The filtered dataset (the same list object as ``dataset``).
    """
    # np.array() copies the data into memory, so the file can be closed
    # before the filtering starts.
    with h5py.File(data_filename, 'r') as f:
        distance_to_center = np.array(f.get('distance_to_center'))
        label_list = np.array(f.get('label_list'))
        image_list = np.array(f.get('image_list'))

    distance_to_center_threshold = find_threshold(distance_to_center, percentile)
    indices = np.where(distance_to_center >= distance_to_center_threshold)[0]
    filtered_dataset = dataset
    removelist = set()
    for i in indices:
        label = label_list[i]
        image = image_list[i]
        # h5py hands string datasets back as bytes on Python 3. Without
        # decoding, the membership test against str paths never matches and
        # nothing is filtered.
        if isinstance(image, bytes):
            image = image.decode('utf-8')
        image_paths = filtered_dataset[label].image_paths
        if image in image_paths:
            image_paths.remove(image)
        if len(image_paths) < min_nrof_images_per_class:
            removelist.add(label)

    # Delete from the highest index down so earlier deletions do not shift
    # the positions of the classes still to be removed.
    for i in sorted(removelist, reverse=True):
        del filtered_dataset[i]

    return filtered_dataset


def train(args, sess, epoch, image_list, label_list, index_dequeue_op, enqueue_op, image_paths_placeholder, labels_placeholder, 
      learning_rate_placeholder, phase_train_placeholder, batch_size_placeholder, global_step, 
      loss, train_op, summary_op, summary_writer, regularization_losses, learning_rate_schedule_file):
    """Enqueue one epoch of samples and run ``args.epoch_size`` training steps.

    The merged summaries are written every 100th batch, and the accumulated
    training time is written once at the end of the epoch.

    Returns:
        The ``global_step`` value fetched in the last training step.
    """
    # A positive --learning_rate is used as-is; otherwise the rate for this
    # epoch is read from the schedule file.
    lr = (args.learning_rate if args.learning_rate > 0.0
          else facenet.get_learning_rate_from_file(learning_rate_schedule_file, epoch))

    # Pick the samples of this epoch via the shuffled index queue.
    epoch_indices = sess.run(index_dequeue_op)
    epoch_labels = np.array(label_list)[epoch_indices]
    epoch_images = np.array(image_list)[epoch_indices]

    # Enqueue one epoch of image paths and labels, as column vectors.
    sess.run(enqueue_op, {
        image_paths_placeholder: np.expand_dims(np.array(epoch_images), 1),
        labels_placeholder: np.expand_dims(np.array(epoch_labels), 1),
    })
    print('training a epoch...')

    # regularization_losses is fetched only so it can be printed; its value is
    # already part of the total loss when weight decay is enabled.
    base_fetches = [loss, train_op, global_step, regularization_losses]
    train_time = 0
    for batch_number in range(args.epoch_size):
        start_time = time.time()
        feed_dict = {learning_rate_placeholder: lr, phase_train_placeholder: True, batch_size_placeholder: args.batch_size}
        if batch_number % 100 != 0:
            err, _, step, reg_loss = sess.run(base_fetches, feed_dict=feed_dict)
        else:
            # Every 100th batch also evaluates and logs the merged summaries.
            err, _, step, reg_loss, summary_str = sess.run(base_fetches + [summary_op], feed_dict=feed_dict)
            summary_writer.add_summary(summary_str, global_step=step)
        duration = time.time() - start_time
        print('Epoch: [%d][%d/%d]\tTime %.3f\tLoss %2.3f\tRegLoss %2.3f' %
              (epoch, batch_number+1, args.epoch_size, duration, err, np.sum(reg_loss)))
        train_time += duration

    # Record the total time spent in the training steps of this epoch.
    summary = tf.Summary()
    # pylint: disable=maybe-no-member
    summary.value.add(tag='time/total', simple_value=train_time)
    summary_writer.add_summary(summary, step)
    return step


def evaluate(sess, enqueue_op, image_paths_placeholder, labels_placeholder, phase_train_placeholder, batch_size_placeholder, 
        embeddings, labels, image_paths, actual_issame, batch_size, nrof_folds, log_dir, step, summary_writer, best_accuracy, saver_save, model_dir, subdir):
    """Embed the LFW images, score them with ``lfw.evaluate`` and checkpoint on improvement.

    The image index is enqueued as the "label", so each embedding coming out
    of the input pipeline can be written back to its original position.

    Returns:
        The best mean accuracy seen so far (``best_accuracy``, or the new mean
        accuracy when it is higher).
    """
    start_time = time.time()
    # Run forward pass to calculate embeddings
    print('Runnning forward pass on LFW images')

    # Enqueue every image path together with its own index as label.
    index_column = np.expand_dims(np.arange(0, len(image_paths)), 1)
    path_column = np.expand_dims(np.array(image_paths), 1)
    sess.run(enqueue_op, {image_paths_placeholder: path_column, labels_placeholder: index_column})

    embedding_size = embeddings.get_shape()[1]
    print("embeddings", embeddings)
    nrof_images = len(actual_issame)*2
    assert nrof_images % batch_size == 0, 'The number of LFW images must be an integer multiple of the LFW batch size'
    nrof_batches = nrof_images // batch_size
    print("number of batchs is ", nrof_batches)

    emb_array = np.zeros((nrof_images, embedding_size))
    lab_array = np.zeros((nrof_images,))
    eval_feed = {phase_train_placeholder: False, batch_size_placeholder: batch_size}
    for _ in range(nrof_batches):
        emb, lab = sess.run([embeddings, labels], feed_dict=eval_feed)
        # Scatter results back by index; lab_array records which slots were filled.
        lab_array[lab] = lab
        emb_array[lab] = emb

    # Every index must have been seen exactly as enqueued.
    assert np.array_equal(lab_array, np.arange(nrof_images)), 'Wrong labels used for evaluation, possibly caused by training examples left in the input pipeline'
    _, _, accuracy, val, val_std, far = lfw.evaluate(emb_array, actual_issame, nrof_folds=nrof_folds)
    mean_accuracy = np.mean(accuracy)

    # Keep a checkpoint only when the model improved on LFW.
    if mean_accuracy > best_accuracy:
        save_variables_and_metagraph(sess, saver_save, summary_writer, model_dir, subdir, step)
        best_accuracy = mean_accuracy

    print('Accuracy: %1.3f+-%1.3f' % (mean_accuracy, np.std(accuracy)))
    print('Validation rate: %2.5f+-%2.5f @ FAR=%2.5f' % (val, val_std, far))
    lfw_time = time.time() - start_time

    # Add validation accuracy and timing to the summary
    summary = tf.Summary()
    # pylint: disable=maybe-no-member
    for tag, value in (('lfw/accuracy', mean_accuracy), ('lfw/val_rate', val), ('time/lfw', lfw_time)):
        summary.value.add(tag=tag, simple_value=value)
    summary_writer.add_summary(summary, step)

    # Append one result line per evaluation to the run's log directory.
    with open(os.path.join(log_dir, 'lfw_result.txt'), 'at') as f:
        f.write('%d\t%.5f\t%.5f\n' % (step, mean_accuracy, val))
    return best_accuracy


def load_data(image_paths):
    """Read images from disk into one normalized array.

    Each image is read with ``misc.imread`` and scaled as
    ``(pixel - 127.5) / 128``.

    Args:
        image_paths: Sequence of image file paths.

    Returns:
        Array of shape ``(len(image_paths), 112, 96, 3)``.
    """
    batch = np.zeros((len(image_paths), 112, 96, 3))
    for row, path in enumerate(image_paths):
        pixels = misc.imread(path)
        batch[row, :, :, :] = (pixels * 1.0 - 127.5) / 128
    return batch


# Compute accuracy statistics for a single threshold.
def calculate_accuracy(threshold, dist, actual_issame):
    """Return ``(tpr, fpr, acc)`` for the given threshold.

    A pair is predicted "same" when its ``dist`` value is strictly greater
    than ``threshold``.
    NOTE(review): for a distance measure the usual convention is "less than";
    confirm that callers pass a similarity score here.

    Args:
        threshold: Decision threshold.
        dist: NumPy array with one score per pair.
        actual_issame: Boolean ground truth per pair.

    Returns:
        True-positive rate, false-positive rate and accuracy. A rate is 0 when
        its denominator is 0.
    """
    predicted_same = np.greater(dist, threshold)
    predicted_diff = np.logical_not(predicted_same)
    actual_diff = np.logical_not(actual_issame)

    # Confusion-matrix counts.
    tp = np.sum(np.logical_and(predicted_same, actual_issame))
    fp = np.sum(np.logical_and(predicted_same, actual_diff))
    tn = np.sum(np.logical_and(predicted_diff, actual_diff))
    fn = np.sum(np.logical_and(predicted_diff, actual_issame))

    positives = tp + fn
    negatives = fp + tn
    if positives == 0:
        tpr = 0
    else:
        tpr = float(tp) / float(positives)
    if negatives == 0:
        fpr = 0
    else:
        fpr = float(fp) / float(negatives)

    # Accuracy = correct predictions over all pairs.
    acc = float(tp + tn) / dist.size
    return tpr, fpr, acc


def evaluate_with_no_cv(emb_array, actual_issame):
    """Find the best accuracy and threshold without cross-validation.

    Rows of ``emb_array`` are read as interleaved pairs: even rows against
    odd rows. The squared Euclidean distance of every pair is scored with
    ``facenet.calculate_accuracy`` for thresholds 0.00, 0.01, ..., 3.99.

    Returns:
        ``(best_accuracy, best_threshold)``; on ties the lowest threshold wins.
    """
    first, second = emb_array[0::2], emb_array[1::2]
    # Squared Euclidean distance per pair.
    dist = np.sum(np.square(np.subtract(first, second)), 1)

    thresholds = np.arange(0, 4, 0.01)
    accuracies = np.array([
        facenet.calculate_accuracy(threshold, dist, actual_issame)[2]
        for threshold in thresholds
    ])

    best_idx = np.argmax(accuracies)
    return np.max(accuracies), thresholds[best_idx]


def evaluate_customize(sess, enqueue_op, image_paths_placeholder, labels_placeholder, phase_train_placeholder, batch_size_placeholder, 
        embeddings, labels, image_paths, actual_issame, batch_size, nrof_folds, log_dir, step, summary_writer, best_accuracy, saver_save, model_dir, subdir):
    """Embed the LFW images and report the best single-threshold accuracy.

    This variant scores with ``evaluate_with_no_cv`` and only prints the
    result and writes it to the summary; it never saves a checkpoint and
    returns ``None``. ``nrof_folds``, ``log_dir``, ``best_accuracy``,
    ``saver_save``, ``model_dir`` and ``subdir`` are accepted but unused.
    """
    start_time = time.time()
    # Run forward pass to calculate embeddings
    print('Runnning forward pass on LFW images')

    # Enqueue every image path with its own index as label.
    index_column = np.expand_dims(np.arange(0, len(image_paths)), 1)
    path_column = np.expand_dims(np.array(image_paths), 1)
    sess.run(enqueue_op, {image_paths_placeholder: path_column, labels_placeholder: index_column})

    embedding_size = embeddings.get_shape()[1]
    nrof_images = len(actual_issame)*2
    assert nrof_images % batch_size == 0, 'The number of LFW images must be an integer multiple of the LFW batch size'
    nrof_batches = nrof_images // batch_size

    emb_array = np.zeros((nrof_images, embedding_size))
    lab_array = np.zeros((nrof_images,))
    eval_feed = {phase_train_placeholder:False, batch_size_placeholder:batch_size}
    for _ in range(nrof_batches):
        emb, lab = sess.run([embeddings, labels], feed_dict=eval_feed)
        # lab_array is used to detect labels left over in the input pipeline.
        lab_array[lab] = lab
        emb_array[lab] = emb

    assert np.array_equal(lab_array, np.arange(nrof_images)), 'Wrong labels used for evaluation, possibly caused by training examples left in the input pipeline'
    accuracy, thre = evaluate_with_no_cv(emb_array, actual_issame)

    print('Accuracy: %1.3f, Threshold: %1.3f' % (accuracy, thre))

    lfw_time = time.time() - start_time
    # Add validation accuracy and timing to the summary
    summary = tf.Summary()
    # pylint: disable=maybe-no-member
    for tag, value in (('lfw/accuracy', accuracy), ('time/lfw', lfw_time)):
        summary.value.add(tag=tag, simple_value=value)
    summary_writer.add_summary(summary, step)


def evaluate_double(sess, enqueue_op, image_paths_placeholder, labels_placeholder, phase_train_placeholder, batch_size_placeholder, 
        embeddings, labels, image_paths, actual_issame, batch_size, nrof_folds, log_dir, step, summary_writer, best_accuracy, saver_save, model_dir, subdir, images_placeholder, args):
    """Evaluate on LFW using the mean embedding of each image and its flipped copy.

    Images are loaded with ``load_data`` and fed directly through
    ``images_placeholder``, bypassing the input queue. The pair list is
    re-read from ``args``, so the ``image_paths``, ``actual_issame`` and
    ``batch_size`` arguments are overridden.

    Returns:
        The best accuracy seen so far; a checkpoint is saved on improvement.
    """
    start_time = time.time()

    # Run forward pass to calculate embeddings
    print('Runnning forward pass on LFW images')
    pairs = lfw.read_pairs(os.path.expanduser(args.lfw_pairs))
    paths, actual_issame = lfw.get_paths(os.path.expanduser(args.lfw_dir), pairs, args.lfw_file_ext)
    batch_size = args.lfw_batch_size
    nrof_images = len(paths)
    # Round up: the last batch may hold fewer than batch_size images.
    nrof_batches = int(math.ceil(1.0*nrof_images / batch_size))
    emb_array = np.zeros((nrof_images, args.embedding_size))
    for batch_idx in range(nrof_batches):
        lo = batch_idx*batch_size
        hi = min(lo + batch_size, nrof_images)
        images = load_data(paths[lo:hi])
        # Flip along axis 2 of the (N, 112, 96, 3) batch.
        images_flip = np.flip(images, 2)
        emb = sess.run(embeddings, feed_dict={images_placeholder: images, phase_train_placeholder: False})
        emb_flip = sess.run(embeddings, feed_dict={images_placeholder: images_flip, phase_train_placeholder: False})
        emb_array[lo:hi, :] = (emb + emb_flip)/2.0

    accuracy, thre = evaluate_with_no_cv(emb_array, actual_issame)

    # Keep a checkpoint only when the accuracy improved.
    if np.mean(accuracy) > best_accuracy:
        save_variables_and_metagraph(sess, saver_save, summary_writer, model_dir, subdir, step)
        best_accuracy = np.mean(accuracy)

    print('Accuracy: %1.3f Threshold: %1.3f' % (accuracy, thre))

    lfw_time = time.time() - start_time
    # Add validation accuracy and timing to the summary
    summary = tf.Summary()
    # pylint: disable=maybe-no-member
    for tag, value in (('lfw/accuracy', accuracy), ('time/lfw', lfw_time)):
        summary.value.add(tag=tag, simple_value=value)
    summary_writer.add_summary(summary, step)

    return best_accuracy


# Save the variables and, once per run, the metagraph.
def save_variables_and_metagraph(sess, saver, summary_writer, model_dir, model_name, step):
    """Write a checkpoint and, if it does not exist yet, the metagraph file.

    Both save durations are recorded in the summary at ``step``; the
    metagraph time is 0 when the file already existed.
    """
    # Save the model checkpoint
    print('Saving variables')
    checkpoint_path = os.path.join(model_dir, 'model-%s.ckpt' % model_name)
    tic = time.time()
    saver.save(sess, checkpoint_path, global_step=step, write_meta_graph=False)
    save_time_variables = time.time() - tic
    print('Variables saved in %.2f seconds' % save_time_variables)

    # The metagraph is exported only when its file is not there yet.
    save_time_metagraph = 0
    metagraph_filename = os.path.join(model_dir, 'model-%s.meta' % model_name)
    if not os.path.exists(metagraph_filename):
        print('Saving metagraph')
        tic = time.time()
        saver.export_meta_graph(metagraph_filename)
        save_time_metagraph = time.time() - tic
        print('Metagraph saved in %.2f seconds' % save_time_metagraph)

    summary = tf.Summary()
    # pylint: disable=maybe-no-member
    for tag, seconds in (('time/save_variables', save_time_variables),
                         ('time/save_metagraph', save_time_metagraph)):
        summary.value.add(tag=tag, simple_value=seconds)
    summary_writer.add_summary(summary, step)


# Parse the command-line arguments.
def parse_arguments(argv):
    """Parse ``argv`` (without the program name) into a namespace.

    Args:
        argv: List of command-line tokens, e.g. ``sys.argv[1:]``.

    Returns:
        The ``argparse.Namespace`` with all training and LFW validation options.
    """
    # (flag, add_argument keyword arguments), in the order they appear in --help.
    option_table = [
        ('--logs_base_dir', dict(type=str,
            help='Directory where to write event logs.', default='./log')),
        ('--models_base_dir', dict(type=str,
            help='Directory where to write trained models and checkpoints.', default='./trained_model')),
        ('--gpu_memory_fraction', dict(type=float,
            help='Upper bound on the amount of GPU memory that will be used by the process.', default=0.7)),
        ('--pretrained_model', dict(type=str,
            help='Load a pretrained model before training starts.')),
        ('--data_dir', dict(type=str,
            help='Path to the data directory containing aligned face patches. Multiple directories are separated with colon.')),
        ('--model_def', dict(type=str,
            help='Model definition. Points to a module containing the definition of the inference graph.', default='models.resface')),
        ('--max_nrof_epochs', dict(type=int,
            help='Number of epochs to run.', default=2)),
        ('--batch_size', dict(type=int,
            help='Number of images to process in a batch.', default=100)),
        ('--epoch_size', dict(type=int,
            help='Number of batches per epoch.', default=100)),
        ('--embedding_size', dict(type=int,
            help='Dimensionality of the embedding.', default=512)),
        ('--random_flip', dict(
            help='Performs random horizontal flipping of training images.', action='store_true')),
        ('--keep_probability', dict(type=float,
            help='Keep probability of dropout for the fully connected layer(s).', default=1.0)),
        ('--weight_decay', dict(type=float,
            help='L2 weight regularization.', default=0.0)),
        ('--learning_rate', dict(type=float,
            help='Initial learning rate. If set to a negative value a learning rate '
                 'schedule can be specified in the file "learning_rate_schedule.txt"', default=0.1)),
        ('--learning_rate_decay_epochs', dict(type=int,
            help='Number of epochs between learning rate decay.', default=100)),
        ('--learning_rate_decay_factor', dict(type=float,
            help='Learning rate decay factor.', default=1.0)),
        ('--seed', dict(type=int,
            help='Random seed.', default=666)),
        ('--nrof_preprocess_threads', dict(type=int,
            help='Number of preprocessing (data loading and augmentation) threads.', default=4)),
        ('--learning_rate_schedule_file', dict(type=str,
            help='File containing the learning rate schedule that is used when learning_rate is set to to -1.', default='data/learning_rate_schedule.txt')),
        ('--filter_filename', dict(type=str,
            help='File containing image data used for dataset filtering', default='')),
        ('--filter_percentile', dict(type=float,
            help='Keep only the percentile images closed to its class center', default=100.0)),
        ('--filter_min_nrof_images_per_class', dict(type=int,
            help='Keep only the classes with this number of examples or more', default=0)),

        # Parameters for validation on LFW
        ('--lfw_pairs', dict(type=str,
            help='The file containing the pairs to use for validation.', default='data/pairs.csv')),
        ('--lfw_file_ext', dict(type=str,
            help='The file extension for the LFW dataset.', default='jpg', choices=['jpg', 'png'])),
        ('--lfw_dir', dict(type=str,
            help='Path to the data directory containing aligned face patches.', default='/home/guobin/lfwdataset')),
        ('--lfw_batch_size', dict(type=int,
            help='Number of images to process in a batch in the LFW test set.  an integer multiple of the LFW batch size', default=100)),
        ('--lfw_nrof_folds', dict(type=int,
            help='Number of folds to use for cross validation. Mainly used for testing.', default=10)),
    ]

    parser = argparse.ArgumentParser()
    for flag, options in option_table:
        parser.add_argument(flag, **options)
    return parser.parse_args(argv)
  

if __name__ == '__main__':
    # Script entry point: parse the command line (without the program name)
    # and start training.
    cli_args = parse_arguments(sys.argv[1:])
    main(cli_args)
        
        
        
